function run_simulation_func(t, rho1, output_name, pre_sample_type, varargin)
% run_simulation_func - Monte Carlo simulation for ITS model
%
% Simulates an interrupted-time-series regression
%   y = b0 + b1*trend + b2*post + b3*post_trend + u,
% with ARMA(1,1) errors u (AR and MA coefficients both equal to rho1),
% computes HAC-based t-statistics for each coefficient in every
% replication, optionally computes AR-sieve bootstrap critical values,
% prints a summary table and saves it to output_name.
%
% INPUTS:
%   t               : number of periods (sample size)
%   rho1            : auto-regressive coefficient (also used as the
%                     moving-average coefficient)
%   output_name     : full path for output .mat file
%   pre_sample_type : 'short' (25%), 'even' (50%), or 'long' (75%)
%
% OPTIONAL NAME-VALUE PAIRS:
%   'num_replications'    : number of MC replications (default: 100)
%   'run_bootstrap'       : true/false to run bootstrap (default: false)
%   'num_bootstrap'       : number of bootstrap replications (default: 1000)
%   'inference_estimator' : 'Textbook-NW' or 'NW-fixed-b' (default: 'Textbook-NW');
%                           any other value selects the rule
%                           floor(6*(t/100)^(1/4))+1 for the truncation lag
%   'critical_value'      : 'Standard-normal' or 'Fixed-b' (default: 'Standard-normal')
%   'p_max'               : max lag for AR model selection (default: 8)
%   'burnin'              : burn-in periods for ARMA generation (default: 10000)
%   'df'                  : degrees of freedom for t-distributed innovations
%                           (default: Inf = Gaussian, use 5 for heavy tails, 15 for moderate)
%   'auto_seed'           : true/false (default: true). When true the seed is
%                           8964 + 1000*N if output_name contains 'run<N>'
%                           (N = first run of digits after 'run'), else 8964;
%                           'base_seed' is ignored in this mode.
%   'base_seed'           : seed used when 'auto_seed' is false (default: 8964)
%
% OUTPUT FILE:
%   output_name contains the table output_T and, when 'run_bootstrap' is
%   true, the matrices power_T_aic and power_T_bic
%   (rows = power-grid points, columns = coefficients, values in percent).
%
% ERRORS:
%   Raises an error for an unknown pre_sample_type or critical_value.
%
% EXAMPLE:
%   run_simulation_func(200, 0.3, 'output.mat', 'even', ...
%                       'num_replications', 100000, 'run_bootstrap', false)
%   run_simulation_func(200, 0.3, 'output.mat', 'even', 'df', 5)  % Heavy-tailed errors

% Parse optional arguments
p_input = inputParser;
addParameter(p_input, 'num_replications', 100, @isnumeric);
addParameter(p_input, 'run_bootstrap', false, @islogical);
addParameter(p_input, 'num_bootstrap', 1000, @isnumeric);
addParameter(p_input, 'inference_estimator', 'Textbook-NW', @ischar);
addParameter(p_input, 'critical_value', 'Standard-normal', @ischar);
addParameter(p_input, 'p_max', 8, @isnumeric);
addParameter(p_input, 'burnin', 10000, @isnumeric);
addParameter(p_input, 'df', Inf, @isnumeric);  % degrees of freedom for t-distribution
addParameter(p_input, 'auto_seed', true, @islogical);
addParameter(p_input, 'base_seed', 8964, @isnumeric);
parse(p_input, varargin{:});

% Extract parameters
num_replications    = p_input.Results.num_replications;
run_bootstrap       = p_input.Results.run_bootstrap;
num_bootstrap       = p_input.Results.num_bootstrap;
inference_estimator = p_input.Results.inference_estimator;
critical_value      = p_input.Results.critical_value;
p_max               = p_input.Results.p_max;
burnin              = p_input.Results.burnin;
df                  = p_input.Results.df;
auto_seed           = p_input.Results.auto_seed;
manual_base_seed    = p_input.Results.base_seed;

% Fixed parameters
show_fig       = false;
asy_cov_matrix = true;
p              = 0.975;           % (1-p)*2*100% type-I error
n_power_test   = 21;              % Number of points for power tests
true_beta      = [1; 1; 1; 1];    % true parameter values
shift_size     = 0;               % for testing power
test_beta      = true_beta + [0, shift_size, 0, shift_size]';  % test parameter values
n_coef         = length(true_beta);
default_seed   = 8964;            % seed used in auto-seed mode

% Define the range for beta values for power testing
% Each parameter will vary within (true_beta - 1) to (true_beta + 1);
% the result is an n_coef-by-n_power_test matrix whose middle column is
% true_beta itself.
beta_for_power_test = arrayfun(@(b) linspace(b - 1, b + 1, n_power_test), ...
    true_beta, 'UniformOutput', false);
beta_for_power_test = cell2mat(beta_for_power_test);

% Determine pre-sample length based on user input
switch pre_sample_type
    case 'short'
        pre_sample_length = floor(5 * t / 20);   % 25% of t
    case 'even'
        pre_sample_length = floor(10 * t / 20);  % 50% of t
    case 'long'
        pre_sample_length = floor(15 * t / 20);  % 75% of t
    otherwise
        error('Invalid pre_sample_type. Choose from "short", "even", or "long".')
end
xi = pre_sample_length/t;   % realised pre-intervention share of the sample

% ARMA parameters
theta1        = rho1;     % moving-average coefficient

% Replications
% Smart seed detection for multi-run support
% NOTE(review): the parallel loop below seeds iteration i with
% base_seed + i, and runs are spaced 1000 apart, so with more than 1000
% replications the per-iteration seeds of consecutive runs overlap.
% Confirm whether that matters for the study design.
if auto_seed
    % Try to extract run number from output path
    tokens = regexp(output_name, 'run(\d+)', 'tokens');
    if ~isempty(tokens)
        run_num = str2double(tokens{1}{1});
        base_seed = default_seed + run_num * 1000;
        fprintf('Multi-run mode detected: run%03d, using seed=%d\n', run_num, base_seed);
    else
        base_seed = default_seed;  % Default seed for single-run mode
        fprintf('Single-run mode: using default seed=%d\n', base_seed);
    end
else
    % Manual seed mode
    base_seed = manual_base_seed;
    fprintf('Manual seed mode: using seed=%d\n', base_seed);
end
rng(base_seed); % set random seed

% Generate ARMA(1,1) errors using standalone function
% Supports Gaussian (df=Inf) or t-distributed (df=5,15,...) innovations
% yy_original is used below as a t-by-num_replications matrix (one
% column per replication).
[yy_original, ~] = generate_arma_process_pq(t, burnin, num_replications, ...
    rho1, theta1, [], df);

disp(['t is ', num2str(t)]);
if isinf(df)
    disp('Innovation distribution: Gaussian');
else
    % %g prints integer df exactly as before and non-integer df legibly.
    fprintf('Innovation distribution: t(%g)\n', df);
end

% Determine truncation parameter based on inference estimator
switch inference_estimator
    case 'NW-fixed-b'
        nw_lag_trunc = floor(1.3 * sqrt(t))+1;
    case 'Textbook-NW'
        nw_lag_trunc = floor(0.75 * t^(1/3))+1;
    otherwise % Default or unrecognized estimator
        kernel_size  = 6;    % Tsay use 6. Schwert (1989) use 4 and 12.
        nw_lag_trunc = floor(kernel_size * ((t / 100)^(1/4)))+1;
end

% Critical value for the two-sided tests. This stays at this position
% (after data generation) so that the order of calls is unchanged.
switch critical_value
    case 'Fixed-b'
        cv = fix_b_asymptotic_cv(1.30/sqrt(t), p);
    case 'Standard-normal'
        cv = icdf('Normal', p, 0, 1);
    otherwise
        % Fail before the simulation loop instead of hitting an undefined
        % cv once all replications have already been computed.
        error('Invalid critical_value. Choose from "Standard-normal" or "Fixed-b".')
end

disp(['NW HAR estimator truncation parameter is ', num2str(nw_lag_trunc)]);

% Generate design matrix for the regression
% First column: Intercept
% Second column: Time trend (1, 2, 3, ..., t)
% Third column: Binary variable for post-treatment periods
% Fourth column: Time trend for post-treatment periods only
post_length = t - pre_sample_length;
x = [ones(t, 1), (1:t)', ...
    [zeros(pre_sample_length, 1); ones(post_length, 1)], ...
    [zeros(pre_sample_length, 1); (1:post_length)']];

% Rescaled design: trend divided by t, post-trend divided by (1-xi)*t.
nx = x;
nx(:, 2) = nx(:, 2)./t;
nx(:, 4) = nx(:, 4)./(1-xi)./t;

% Generate the y values (observations) based on the design matrix and true parameters
% Adding the ARMA(1,1) noise (yy_original)
yy = yy_original + x * true_beta;

% True covariance matrix
if asy_cov_matrix == true
    Xi_11 = [1, 1/2; 1/2, 1/3];
    Xi_12 = [sqrt(1-xi), sqrt(1-xi)/2; sqrt(1-xi)*(1+xi)/2, sqrt(1-xi)*(2+xi)/6];
    Xi_21 = Xi_12';
    Xi_22 = Xi_11;
else
    Xi_11 = [ 1, (t+1)/(2*t); (t+1)/(2*t), ((2*t+1)*(t+1))/(6*t^2)];
    Xi_12 = [sqrt(1-xi), (1+(1-xi)*t)/(2*t*(1-xi)^(1/2));...
        ((1-xi)^(1/2)*((1+xi)*t+1))/(2*t), ...
        (1+(1-xi)*(1+xi)*t^2+3*t)...
        /(6*t^2*(1-xi)^(1/2))];
    Xi_21 = Xi_12';
    Xi_22 = [1, ((1-xi)*t+1)/(2*(1-xi)*t); ((1-xi)*t+1)/(2*(1-xi)*t), ...
        (2*(1-xi)^2*t^2+3*(1-xi)*t+1)/(6*t^2*(1-xi)^2)];
end

cov_mat = [Xi_11, Xi_12; Xi_21, Xi_22];

% Diagonal of the normalisation matrix applied to X'X below
norm_mat = [sqrt(t), t^(3/2), sqrt(t*(1-xi)), (t*(1-xi))^(3/2)];

%==========================================================================
% Compare the asymptotic covariance matrix with Λ^(-1)X'XΛ^(-1)
% This comparison checks how closely the asymptotic covariance matrix
% approximates the actual one

% Display a message explaining the matrix comparison
disp('Difference between asymptotic covariance matrix and normalized X''X matrix:')
disp(cov_mat - diag(norm_mat)\x'*x/diag(norm_mat))
%==========================================================================

% Initialize results storage with preallocation
est_ols                    = zeros(num_replications, n_coef);
test_ols                   = zeros(num_replications, n_coef);
test_ols_unadjusted        = zeros(num_replications, n_coef);
test_ols_normalized        = zeros(num_replications, n_coef);
% NOTE(review): long_run_variance is never filled in anywhere in this
% function, so the mean printed at the end is always 0. The first output
% of compute_ols_inference is discarded; confirm whether it should be
% stored here.
long_run_variance          = zeros(num_replications, 1);

% Bootstrap-specific storage (only if needed)
if run_bootstrap
    test_ols_outcome_power_aic = zeros(num_replications, n_coef, n_power_test);
    test_ols_outcome_power_bic = zeros(num_replications, n_coef, n_power_test);
    test_ols_outcome_aic       = zeros(num_replications, n_coef);
    test_ols_outcome_bic       = zeros(num_replications, n_coef);
end

% Initialize parallel pool if available
use_parallel = true;
try
    poolobj = gcp('nocreate'); % Check if pool exists
    if isempty(poolobj)
        poolobj = parpool('local'); % Create local pool
    end
    fprintf('Using parallel pool with %d workers\n', poolobj.NumWorkers);
    % Give every worker an mrg32k3a global stream. The parfor loop below
    % reseeds with rng(base_seed + iteration) at the start of every
    % iteration, so the seed passed here does not determine the draws; it
    % is tied to base_seed only for consistency.
    % NOTE(review): poolobj is referenced inside the spmd body; if that is
    % not permitted in the MATLAB release in use, the failure is swallowed
    % by the catch below and workers keep their default stream. Verify.
    try
        spmd
            s = RandStream.create('mrg32k3a', 'NumStreams', poolobj.NumWorkers, 'StreamIndices', labindex, 'Seed', base_seed);
            RandStream.setGlobalStream(s);
        end
    catch
        % Best effort only: every parfor iteration reseeds on its own.
    end
catch
    warning('Parallel Computing Toolbox not available. Running serial.');
    use_parallel = false;
end

% Settings shared by every replication (broadcast to workers in parfor)
cfg = struct( ...
    'xi',                  xi, ...
    't',                   t, ...
    'nw_lag_trunc',        nw_lag_trunc, ...
    'beta_for_power_test', beta_for_power_test, ...
    'test_beta',           test_beta, ...
    'n_power_test',        n_power_test, ...
    'run_bootstrap',       run_bootstrap, ...
    'p_max',               p_max, ...
    'burnin',              burnin, ...
    'num_bootstrap',       num_bootstrap, ...
    'p',                   p);

% Main simulation loop
if use_parallel
    parfor iteration = 1:num_replications
        y = yy(:, iteration);
        % Reseed at the start of every iteration so the bootstrap draws of
        % a replication do not depend on which worker runs it.
        try
            rng(base_seed + iteration);
        catch
            % If RNG fails, ignore (rare)
        end

        [est_i, t_i, t_unadj_i, t_norm_i, rej_aic_i, rej_bic_i, pow_aic_i, pow_bic_i] = ...
            rsf_run_replication(y, x, nx, cfg);

        est_ols(iteration, :)             = est_i;
        test_ols(iteration, :)            = t_i;
        test_ols_unadjusted(iteration, :) = t_unadj_i;
        test_ols_normalized(iteration, :) = t_norm_i;

        % Bootstrap test outcomes (only if run_bootstrap = true)
        if run_bootstrap
            test_ols_outcome_aic(iteration, :)          = rej_aic_i;
            test_ols_outcome_bic(iteration, :)          = rej_bic_i;
            test_ols_outcome_power_aic(iteration, :, :) = pow_aic_i;
            test_ols_outcome_power_bic(iteration, :, :) = pow_bic_i;
        end
    end
else
    % Serial fallback. As in the original implementation there is no
    % per-iteration reseeding here: draws continue from rng(base_seed).
    % NOTE(review): bootstrap results therefore differ between the serial
    % and the parallel path; confirm whether that is acceptable.
    for iteration = 1:num_replications
        y = yy(:, iteration);

        [est_i, t_i, t_unadj_i, t_norm_i, rej_aic_i, rej_bic_i, pow_aic_i, pow_bic_i] = ...
            rsf_run_replication(y, x, nx, cfg);

        est_ols(iteration, :)             = est_i;
        test_ols(iteration, :)            = t_i;
        test_ols_unadjusted(iteration, :) = t_unadj_i;
        test_ols_normalized(iteration, :) = t_norm_i;

        % Bootstrap test outcomes (only if run_bootstrap = true)
        if run_bootstrap
            test_ols_outcome_aic(iteration, :)          = rej_aic_i;
            test_ols_outcome_bic(iteration, :)          = rej_bic_i;
            test_ols_outcome_power_aic(iteration, :, :) = pow_aic_i;
            test_ols_outcome_power_bic(iteration, :, :) = pow_bic_i;
        end
    end
end

% Output results
disp('#Replication: '), disp(num_replications);
disp('Sample size: '), disp(t);
test_ols_outcome = abs(test_ols) > cv;
test_ols_outcome_unadjusted = abs(test_ols_unadjusted) > cv;
test_ols_outcome_normalized = abs(test_ols_normalized) > cv;
disp('Mean of long run variance: '), disp(mean(long_run_variance));

% Create output table
% All reductions are taken explicitly along dimension 1 (replications) so
% that a single replication still produces one value per coefficient.
mean_est = mean(est_ols, 1);
output = [true_beta';...
    round(mean_est, 4); ...
    round(true_beta' - mean_est, 4);...
    round(rmse(est_ols, true_beta', 1), 4); ...
    100 * mean(test_ols_outcome, 1); ...
    100 * mean(test_ols_outcome_unadjusted, 1); ...
    100 * mean(test_ols_outcome_normalized, 1)...
    ];
row_names = {'True parameter', 'Average estimates', ...
    'Average bias', 'RMSE', 'Rejection ratio', 'Rejection ratio (unadjusted)', ...
    'Rejection ratio (normalized)'};

if run_bootstrap
    output = [output; ...
        100 * mean(test_ols_outcome_aic, 1); ...
        100 * mean(test_ols_outcome_bic, 1)...
        ];
    row_names = [row_names, ...
        {'Rejection ratio bootstrap AIC', 'Rejection ratio bootstrap BIC'}];

    % Rejection percentages per power-grid point:
    % n_power_test-by-n_coef (rows = grid points, columns = coefficients).
    power_T_aic = 100 * permute(mean(test_ols_outcome_power_aic, 1), [3, 2, 1]);
    power_T_bic = 100 * permute(mean(test_ols_outcome_power_bic, 1), [3, 2, 1]);
end

output_T = array2table(output, "RowNames", row_names, ...
    "VariableNames", {'b0', 'b1', 'b2', 'b3'});

disp(output_T)

if show_fig == true
    % Plot the noiseless regression line with +/- one cross-replication
    % standard deviation of the errors, separately before and after the
    % intervention.
    true_y   = x*true_beta;
    std_e    = std(yy_original, 0, 2);
    pre_idx  = 1:pre_sample_length;
    post_idx = pre_sample_length+1:t;

    hold on
    plot(pre_idx, true_y(pre_idx), LineWidth= 2.5, Color= 'blue')
    plot(pre_idx, true_y(pre_idx)+std_e(pre_idx), LineWidth= 0.5, Color= 'blue')
    plot(pre_idx, true_y(pre_idx)-std_e(pre_idx), LineWidth= 0.5, Color= 'blue')
    % Column of red dots at the first post-intervention period
    plot(pre_sample_length+1, 1:true_y(t), 'r.')
    plot(post_idx, true_y(post_idx), LineWidth= 2.5, Color= 'blue')
    plot(post_idx, true_y(post_idx)+std_e(post_idx), LineWidth= 0.5, Color= 'blue')
    plot(post_idx, true_y(post_idx)-std_e(post_idx), LineWidth= 0.5, Color= 'blue')
end

% Save results (conditionally include bootstrap outputs)
if run_bootstrap
    save(output_name, "output_T", "power_T_aic", "power_T_bic")
else
    save(output_name, "output_T")
end

% Note: ARMA process generation is now handled by the standalone function
% generate_arma_process_pq.m which supports:
%   - General ARMA(p,q) processes
%   - Gaussian or t-distributed innovations
%   - Bootstrap resampling

end  % End of run_simulation_func


function [est_row, t_true, t_unadj, t_norm, rej_aic, rej_bic, pow_aic, pow_bic] = ...
        rsf_run_replication(y, x, nx, cfg)
% rsf_run_replication - all computations for one Monte Carlo replication.
%
% INPUTS:
%   y   : t-by-1 simulated outcome
%   x   : t-by-k design matrix
%   nx  : rescaled copy of x (trend columns divided by their length)
%   cfg : struct of settings built in run_simulation_func
%
% OUTPUTS (row vectors of length k unless stated otherwise):
%   est_row : OLS estimates
%   t_true  : test statistics from compute_ols_inference at the middle
%             column of cfg.beta_for_power_test
%   t_unadj : (estimate - test_beta) / HAC standard error of the raw model
%   t_norm  : same statistic computed from the model fitted on nx
%   rej_aic, rej_bic : bootstrap rejection indicators (0/1); [] when
%                      cfg.run_bootstrap is false
%   pow_aic, pow_bic : 1-by-k-by-n_power_test rejection indicators over
%                      the power grid; [] when cfg.run_bootstrap is false

    % Fit OLS model
    [~, ests, residuals, tests, mdl] = compute_ols_inference(...
        y, x, cfg.xi, cfg.nw_lag_trunc, cfg.beta_for_power_test, true);

    est_col = ests(:);
    est_row = est_col';

    rej_aic = [];
    rej_bic = [];
    pow_aic = [];
    pow_bic = [];

    % Bootstrap critical values are computed before the remaining model
    % fits, matching the original order of calls.
    if cfg.run_bootstrap
        [crit_aic, crit_bic] = rsf_bootstrap_critical_values( ...
            residuals, est_col, x, cfg);
    end

    mdl_normalized = fitlm(nx, y, 'intercept', false);

    % Test statistics
    t_true = tests(:, ceil(cfg.n_power_test / 2))';  % Select the test for true beta

    deviation = est_col - cfg.test_beta;
    t_unadj = (deviation ./ rsf_hac_se(mdl, cfg.nw_lag_trunc))';

    % Undo the column rescaling of nx so the estimates are on the scale of
    % the normalized model before dividing by its standard errors.
    t_norm = (deviation ...
        ./ [1; 1/cfg.t; 1; 1/(1-cfg.xi)/cfg.t] ...
        ./ rsf_hac_se(mdl_normalized, cfg.nw_lag_trunc))';

    if cfg.run_bootstrap
        k = numel(est_col);
        grid_tests = tests(:, 1:cfg.n_power_test);

        rej_aic = double(rsf_outside_band(t_true', crit_aic))';
        rej_bic = double(rsf_outside_band(t_true', crit_bic))';

        % k-by-n_power_test indicators reshaped to 1-by-k-by-n_power_test
        pow_aic = reshape(double(rsf_outside_band(grid_tests, crit_aic)), ...
            1, k, cfg.n_power_test);
        pow_bic = reshape(double(rsf_outside_band(grid_tests, crit_bic)), ...
            1, k, cfg.n_power_test);
    end
end


function [crit_aic, crit_bic] = rsf_bootstrap_critical_values(residuals, est_col, x, cfg)
% rsf_bootstrap_critical_values - bootstrap-t critical values.
%
% Fits AR models to the OLS residuals (orders chosen by AIC and BIC),
% regenerates num_bootstrap error series from each fitted model, re-runs
% the OLS inference on every bootstrap sample and returns, per
% coefficient, the lower and upper order statistics of the bootstrap
% t-values as a k-by-2 matrix [lower, upper].

    t = cfg.t;
    num_bootstrap = cfg.num_bootstrap;
    k = numel(est_col);

    % Fit AR model on the residuals and select the best AR model
    % based on AIC and BIC criteria.
    [best_p_aic, best_model_aic, best_p_bic, best_model_bic, ~, ~] = ...
        choose_best_ar_model(residuals, cfg.p_max);

    % Infer residuals from the selected AR models.
    AR_resid_aic = infer(best_model_aic, residuals);
    AR_resid_bic = infer(best_model_bic, residuals);

    % Demean the residuals from AR process
    AR_resid_aic = AR_resid_aic - sum(AR_resid_aic)./(t-best_p_aic);
    AR_resid_bic = AR_resid_bic - sum(AR_resid_bic)./(t-best_p_bic);

    % Generate bootstrap errors (AIC first, then BIC, so the random draws
    % are consumed in the same order as before)
    [~, AR_resid_aic_star] = generate_arma_process_pq(t, cfg.burnin, ...
        num_bootstrap, cell2mat(best_model_aic.AR), [], AR_resid_aic);
    [~, AR_resid_bic_star] = generate_arma_process_pq(t, cfg.burnin, ...
        num_bootstrap, cell2mat(best_model_bic.AR), [], AR_resid_bic);

    % Generate bootstrap observations
    y_hat = x*est_col;
    y_hat_aic = y_hat + AR_resid_aic_star;
    y_hat_bic = y_hat + AR_resid_bic_star;

    % Compute bootstrap-t values (centred at the OLS estimates)
    bootstrap_t_values_aic = zeros(k, num_bootstrap);
    bootstrap_t_values_bic = zeros(k, num_bootstrap);

    for iter_bootstrap = 1:num_bootstrap
        [~, ~, ~, bootstrap_t_values_aic(:, iter_bootstrap), ~] = ...
            compute_ols_inference(y_hat_aic(:, iter_bootstrap), x, cfg.xi, ...
            cfg.nw_lag_trunc, est_col, true);
        [~, ~, ~, bootstrap_t_values_bic(:, iter_bootstrap), ~] = ...
            compute_ols_inference(y_hat_bic(:, iter_bootstrap), x, cfg.xi, ...
            cfg.nw_lag_trunc, est_col, true);
    end

    % Sort the bootstrap t-values of each coefficient
    bootstrap_t_values_aic = sort(bootstrap_t_values_aic, 2);
    bootstrap_t_values_bic = sort(bootstrap_t_values_bic, 2);

    % Calculate the indices for the (1-p) and p percentiles
    % NOTE(review): (1 - p) * num_bootstrap is evaluated in floating point;
    % for p = 0.975 and num_bootstrap = 1000 the product is slightly above
    % 25, so ceil returns 26 rather than 25. Left unchanged to keep
    % results identical; confirm the intended order statistic.
    idx_low = ceil((1 - cfg.p) * num_bootstrap);
    idx_high = ceil(cfg.p * num_bootstrap);

    % Store the (1-p) and p percentiles in the respective matrices
    crit_aic = bootstrap_t_values_aic(:, [idx_low, idx_high]);
    crit_bic = bootstrap_t_values_bic(:, [idx_low, idx_high]);
end


function se = rsf_hac_se(mdl, bandwidth)
% rsf_hac_se - HAC standard errors (Bartlett weights) of a fitted model.
% Takes the square root of the diagonal only, which avoids forming
% complex intermediates for negative off-diagonal covariance entries.
    se = sqrt(diag(hac(mdl, Intercept = false, ...
        Type = "HAC", Weight = "BT", Display = "off", ...
        Bandwidth = bandwidth, SmallT = true)));
end


function reject = rsf_outside_band(stat, crit)
% rsf_outside_band - true where stat lies on or outside [crit(:,1), crit(:,2)].
% stat is k-by-m (one row per coefficient), crit is k-by-2 [lower, upper].
    reject = (stat >= crit(:, 2)) | (stat <= crit(:, 1));
end